
Conversation

@hsjts0u (Contributor) commented Sep 4, 2025

Add default arguments to `_aten_conv2d`, which would otherwise fail on the following code snippet:

```python
import torch
from torch.export import export_for_training
import torchax
from torchax import interop
from torch.utils import _pytree as pytree
import jax
from torchax.ops import mappings


class Simple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False means the traced graph calls aten.conv2d with only
        # (input, weight), relying on the schema defaults for the rest.
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=4, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        return x


model = Simple()

exported = export_for_training(model, (torch.randn(1, 3, 224, 224),))


def make_shape_struct(x):
    return jax.ShapeDtypeStruct(x.shape, mappings.t2j_dtype(x.dtype))


def map_nth(v, i):
    # Slice every tensor in the pytree to its i-th example, keeping the
    # batch dimension.
    def f(t):
        if isinstance(t, torch.Tensor):
            return t[i : i + 1]
        return t

    return pytree.tree_map(f, v)


env = torchax.default_env()
with env:
    model = exported.module().to("jax")

    def func_to_export(x):
        # hard code weights in model
        return model(x)

    example_inputs_jax = pytree.tree_map_only(
        torch.Tensor, lambda x: x.to("jax"), map_nth(exported.example_inputs, 0)
    )

    # Lowering traces through torchax's aten implementations and fails with:
    # TypeError: _aten_conv2d() missing 5 required positional arguments:
    # 'bias', 'stride', 'padding', 'dilation', and 'groups'
    res = jax.jit(interop.jax_view(func_to_export)).lower(*example_inputs_jax[0])
```
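The failure happens because the exported graph calls `aten.conv2d.default` with just `(input, weight)` and relies on the operator's schema defaults, so torchax's Python implementation has to declare matching defaults. A minimal sketch of the shape of the fix, assuming a JAX-array-based implementation (illustrative only, not the actual torchax source; the defaults mirror `torch.nn.functional.conv2d`):

```python
import jax


def _aten_conv2d(input, weight, bias=None, stride=1, padding=0,
                 dilation=1, groups=1):
    # Sketch only: declaring defaults lets the positional call
    # conv2d(input, weight) from the exported graph succeed.
    to_pair = lambda v: (v, v) if isinstance(v, int) else tuple(v)
    stride, padding, dilation = map(to_pair, (stride, padding, dilation))
    out = jax.lax.conv_general_dilated(
        input, weight,
        window_strides=stride,
        padding=[(p, p) for p in padding],
        rhs_dilation=dilation,            # kernel (atrous) dilation
        feature_group_count=groups,       # grouped convolution
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    if bias is not None:
        out = out + bias.reshape(1, -1, 1, 1)
    return out
```

With the defaults in place, the `jax.jit(...).lower(...)` call in the repro above no longer raises.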

cc @qihqi

@hsjts0u changed the title from "Add default args for _aten_con2d" to "Add default args for _aten_conv2d" on Sep 4, 2025
@qihqi enabled auto-merge (squash) on September 18, 2025 at 00:23
@qihqi (Collaborator) commented Sep 18, 2025

Thanks!

@qihqi (Collaborator) commented Sep 19, 2025

Hi @hsjts0u, would you rebase onto the latest HEAD? It should fix the CI issue.

@hsjts0u (Contributor, Author) commented Sep 21, 2025

Done

@qihqi disabled auto-merge on September 29, 2025 at 22:08
@qihqi (Collaborator) commented Sep 29, 2025

I'll merge, given that CI still needs some work (bazel cache) and that issue is independent of this change. The change itself is verified to be good by the torchax CIs.

@qihqi merged commit 03d4dc0 into pytorch:master on Sep 29, 2025 (12 of 13 checks passed)